/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Testbed for sanity-checking internal parameters of a device-wide GEMM
*/

#include <iostream>
#include <fstream>
#include <sstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/core_io.h"

#include "testbed.h"

namespace test {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

//
// List of Gemm internal parameters that this testbed can verify against
// user-provided reference values
//
/// Identifies which Gemm internal parameter a reference value refers to.
enum class ParameterID {
    //
    // Threadblock-level parameters
    //
    kSmemASize,
    kSmemBSize,

    //
    // Warp-level parameters
    //
    kWarpFragmentASize,
    kWarpFragmentBSize,
    kWarpFragmentCSize,

    // Belongs to neither group above; judging by its name, a placeholder for
    // an invalid / unset ID.
    kInvalid
};

/// One user-provided expectation: the parameter to check, the value it is
/// expected to have, and an optional message.
struct Reference {
    /// Parameter this reference applies to
    ParameterID parameter_id;

    /// Expected value. The alternatives share storage; the constructor below
    /// only ever initializes `value`.
    union {
        int value;

        struct {
            int m;
            int n;
            int k;
        } gemm_shape;

        struct {
            int row;
            int column;
        } matrix_shape;
    };

    /// Message supplied by the user together with the reference value
    std::string error_msg;

    /// Builds a reference for \p parameter_id_. The expected value defaults
    /// to -1 and the message to an empty string.
    Reference(ParameterID parameter_id_,
              int value_ = -1,
              std::string const& error_msg_ = "")
            : parameter_id(parameter_id_),
              value(value_),
              error_msg(error_msg_) {}
};

template <typename Gemm>
struct TestbedSanity {
    //
    // Type definitions (All Gemm types top down)
    //

    // Unpacking Gemm types in the following order
    // Kernel-level > Threadblock-level > Warp-level > Instruction-level

    // kernel-level cutlass Gemm
    using GemmKernel = typename Gemm::GemmKernel;

    //
    // Threadblock-level gemm types
    //
    using MmaThreadBlock = typename GemmKernel::Mma;

    // Threadblock-level gemm shape covering one stage
    using ThreadblockShape = typename MmaThreadBlock::Shape;

    // Shared memory size covering all stages
    using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA;
    using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA;
    using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB;
    using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB;

    /// Number of stages
    static int const kStages = MmaThreadBlock::Base::kStages;

    /// Number of warp-level GEMM oeprations
    static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations;

    //
    // Warp-level gemm types
    //

    // Warp-level gemm operator
    using MmaWarp = typename MmaThreadBlock::Operator;

    // Warp-level gemm shape covering all kgroups
    using WarpShape = typename MmaWarp::Shape;

    // Warp-level framents holding operands A & B operand and destination C
    using WarpFragmentA = typename MmaWarp::FragmentA;
    using WarpFragmentB = typename MmaWarp::FragmentB;
    using WarpFragmentC = typename MmaWarp::FragmentC;

    //
    // Instruction-level gemm types
    //

    // Instruction-level gemm operator
    using MmaInstruction = typename MmaWarp::Policy::Operator;

    // Instruction shape
    using InstructionShape = typename MmaInstruction::Shape;

    // Instruction-level framents holding operands A & B operand and destination
    // C
    using InstructionFragmentA = typename MmaInstruction::FragmentA;
    using InstructionFragmentB = typename MmaInstruction::FragmentB;
    using InstructionFragmentC = typename MmaInstruction::FragmentC;

    //
    // Testbed types
    //

    // Vector of values holding user provided reference
    using ReferenceVector = std::vector<Reference>;

    //
    // Data members
    //
    ReferenceVector references;

    //
    // Methods
    //

    TestbedSanity(ReferenceVector const& references_ = ReferenceVector())
            : references(references_) {}

    // verify all parameter in ReferenceVector
    bool verify() {
        for (auto ref : references)
            verify_parameter(ref);
        return true;
    }

    // verify parameter of type Reference
    void verify_parameter(Reference const& ref) {
        switch (ref.parameter_id) {
            case ParameterID::kWarpFragmentASize:
                EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this;
                break;
            case ParameterID::kWarpFragmentBSize:
                EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this;
                break;
            case ParameterID::kWarpFragmentCSize:
                EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this;
                break;
        }
    }
};

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//                             Overload output operators for TestbedSanity<Gemm>
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Gemm>
std::ostream& operator<<(std::ostream& out, TestbedSanity<Gemm> const& test) {
    out << "Gemm internal parameters" << std::endl
        << "  Threadblock-level parameters:" << std::endl
        << "     ThreadblockShape = " <<
            typename TestbedSanity<Gemm>::ThreadblockShape() << std::endl
        << "     kStages = " << TestbedSanity<Gemm>::kStages << std::endl
        << "     kWarpGemmIterations = "
        << TestbedSanity<Gemm>::kWarpGemmIterations << std::endl
        << "  Shared memory sizes:" << std::endl
        << "    SmemPaddingA = " << typename TestbedSanity<Gemm>::SmemPaddingA()
        << std::endl
        << "    SmemPaddingB = " << typename TestbedSanity<Gemm>::SmemPaddingB()
        << std::endl
        << "      SmemShapeA = " << typename TestbedSanity<Gemm>::SmemShapeA()
        << std::endl
        << "      SmemShapeB = " << typename TestbedSanity<Gemm>::SmemShapeB()
        << std::endl
        << "  Warp-level parameters" << std::endl
        << "    WarpShape = " << typename TestbedSanity<Gemm>::WarpShape()
        << std::endl
        << "    Fragment sizes:" << std::endl
        << "      WarpFragmentA::kElements = "
        << TestbedSanity<Gemm>::WarpFragmentA::kElements << std::endl
        << "      WarpFragmentB::kElements = "
        << TestbedSanity<Gemm>::WarpFragmentB::kElements << std::endl
        << "      WarpFragmentC::kElements = "
        << TestbedSanity<Gemm>::WarpFragmentC::kElements << std::endl
        << "  Instruction-level parameters" << std::endl
        << "    InstructionShape = " <<
            typename TestbedSanity<Gemm>::InstructionShape() << std::endl
        << "    Fragment sizes:" << std::endl
        << "      InstructionFragmentA::kElements = "
        << TestbedSanity<Gemm>::InstructionFragmentA::kElements << std::endl
        << "      InstructionFragmentB::kElements = "
        << TestbedSanity<Gemm>::InstructionFragmentB::kElements << std::endl
        << "      InstructionFragmentC::kElements = "
        << TestbedSanity<Gemm>::InstructionFragmentC::kElements << std::endl;

    return out;
}

}  // namespace device
}  // namespace gemm
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////
